{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "nF07sDgJWHQy"
   },
   "outputs": [],
   "source": [
    "# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "BRcOouCyWHQ2"
   },
   "source": [
    "# Deform a source mesh to form a target mesh using 3D loss functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "HfwwW9HqtuvQ"
   },
   "source": [
    "In this tutorial, we learn to deform an initial generic shape (e.g. sphere) to fit a target shape.\n",
    "\n",
    "We will cover: \n",
    "\n",
    "- How to **load a mesh** from an `.obj` file\n",
    "- How to use the PyTorch3D **Meshes** datastructure\n",
    "- How to use 4 different PyTorch3D **mesh loss functions**\n",
    "- How to set up an **optimization loop**\n",
    "\n",
    "\n",
    "Starting from a sphere mesh, we learn the offset to each vertex in the mesh such that\n",
    "the predicted mesh is closer to the target mesh at each optimization step. To achieve this we minimize:\n",
    "\n",
    "+ `chamfer_distance`, the distance between the predicted (deformed) and target mesh, defined as the chamfer distance between the set of pointclouds resulting from **differentiably sampling points** from their surfaces. \n",
    "\n",
    "However, solely minimizing the chamfer distance between the predicted and the target mesh will lead to a non-smooth shape (verify this by setting  `w_chamfer=1.0` and all other weights to `0.0`). \n",
    "\n",
    "We enforce smoothness by adding **shape regularizers** to the objective. Namely, we add:\n",
    "\n",
    "+ `mesh_edge_length`, which minimizes the length of the edges in the predicted mesh.\n",
    "+ `mesh_normal_consistency`, which enforces consistency across the normals of neighboring faces.\n",
    "+ `mesh_laplacian_smoothing`, which is the laplacian regularizer."
   ]
  },
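  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the chamfer term concrete, here is a minimal sketch in plain PyTorch for two unbatched point clouds. It assumes squared Euclidean distances and a mean reduction over points; the helper name `naive_chamfer` is hypothetical, and this is not PyTorch3D's implementation, which additionally handles batching and point clouds of varying size.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "def naive_chamfer(x, y):\n",
    "    # x: (N, 3) and y: (M, 3) point clouds (unbatched, for illustration only).\n",
    "    # Pairwise squared Euclidean distances, shape (N, M).\n",
    "    d = ((x[:, None, :] - y[None, :, :]) ** 2).sum(dim=-1)\n",
    "    # For each point, take the squared distance to its nearest neighbor in the\n",
    "    # other cloud, average per direction, and sum the two directions.\n",
    "    return d.min(dim=1).values.mean() + d.min(dim=0).values.mean()\n",
    "\n",
    "\n",
    "a = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])\n",
    "b = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 2.0, 0.0]])\n",
    "print(naive_chamfer(a, a).item())  # 0.0: identical clouds\n",
    "print(naive_chamfer(a, b).item())  # about 1.3333 (= 4/3)\n",
    "```\n",
    "\n",
    "The second value is nonzero only because the extra point in `b` has no close neighbor in `a`, which is why the distance is accumulated in both directions."
   ]
  },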
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "P-h1ji4dWHQ5"
   },
   "source": [
    "## 0. Install and Import modules"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Ensure `torch` and `torchvision` are installed. If `pytorch3d` is not installed, install it using the following cell:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "_qkuyhyTeRyM"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import torch\n",
    "need_pytorch3d=False\n",
    "try:\n",
    "    import pytorch3d\n",
    "except ModuleNotFoundError:\n",
    "    need_pytorch3d=True\n",
    "if need_pytorch3d:\n",
    "    if torch.__version__.startswith(\"1.9\") and sys.platform.startswith(\"linux\"):\n",
    "        # We try to install PyTorch3D via a released wheel.\n",
    "        version_str=\"\".join([\n",
    "            f\"py3{sys.version_info.minor}_cu\",\n",
    "            torch.version.cuda.replace(\".\",\"\"),\n",
    "            f\"_pyt{torch.__version__[0:5:2]}\"\n",
    "        ])\n",
    "        !pip install pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
    "    else:\n",
    "        # We try to install PyTorch3D from source.\n",
    "        !curl -LO https://github.com/NVIDIA/cub/archive/1.10.0.tar.gz\n",
    "        !tar xzf 1.10.0.tar.gz\n",
    "        os.environ[\"CUB_HOME\"] = os.getcwd() + \"/cub-1.10.0\"\n",
    "        !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ylbZGXYBtuvB"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "from pytorch3d.io import load_obj, save_obj\n",
    "from pytorch3d.structures import Meshes\n",
    "from pytorch3d.utils import ico_sphere\n",
    "from pytorch3d.ops import sample_points_from_meshes\n",
    "from pytorch3d.loss import (\n",
    "    chamfer_distance, \n",
    "    mesh_edge_loss, \n",
    "    mesh_laplacian_smoothing, \n",
    "    mesh_normal_consistency,\n",
    ")\n",
    "import numpy as np\n",
    "from tqdm.notebook import tqdm\n",
    "%matplotlib notebook \n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib as mpl\n",
    "mpl.rcParams['savefig.dpi'] = 80\n",
    "mpl.rcParams['figure.dpi'] = 80\n",
    "\n",
    "# Set the device\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:0\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "    print(\"WARNING: CPU only, this will be slow!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "yT1JTXu1WHQ_"
   },
   "source": [
    "## 1. Load an obj file and create a Meshes object"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Download the target 3D model of a dolphin. It will be saved locally as a file called `dolphin.obj`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 204
    },
    "colab_type": "code",
    "id": "oFNkB6nQWZSw",
    "outputId": "c1bbe6e2-a4ea-4113-d53d-1cb1ece130f1"
   },
   "outputs": [],
   "source": [
    "!wget https://dl.fbaipublicfiles.com/pytorch3d/data/dolphin/dolphin.obj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "dz0imH-ltuvS"
   },
   "outputs": [],
   "source": [
    "# Load the dolphin mesh.\n",
    "trg_obj = os.path.join('dolphin.obj')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "rbyRhI8ituvW"
   },
   "outputs": [],
   "source": [
    "# We read the target 3D model using load_obj\n",
    "verts, faces, aux = load_obj(trg_obj)\n",
    "\n",
    "# verts is a FloatTensor of shape (V, 3) where V is the number of vertices in the mesh\n",
    "# faces is an object which contains the following LongTensors: verts_idx, normals_idx and textures_idx\n",
    "# For this tutorial, normals and textures are ignored.\n",
    "faces_idx = faces.verts_idx.to(device)\n",
    "verts = verts.to(device)\n",
    "\n",
    "# We scale normalize and center the target mesh to fit in a sphere of radius 1 centered at (0,0,0). \n",
    "# (scale, center) will be used to bring the predicted mesh to its original center and scale\n",
    "# Note that normalizing the target mesh, speeds up the optimization but is not necessary!\n",
    "center = verts.mean(0)\n",
    "verts = verts - center\n",
    "scale = max(verts.abs().max(0)[0])\n",
    "verts = verts / scale\n",
    "\n",
    "# We construct a Meshes structure for the target mesh\n",
    "trg_mesh = Meshes(verts=[verts], faces=[faces_idx])"
   ]
  },
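  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The centering and scaling used above, together with the inverse transform applied before saving the result, can be checked on a toy example. This is only a sketch in plain PyTorch; the vertex values and the `toy_` names are made up for illustration.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# Four made-up vertices, shape (V, 3).\n",
    "toy_verts = torch.tensor([\n",
    "    [0.0, 0.0, 0.0],\n",
    "    [2.0, 0.0, 0.0],\n",
    "    [0.0, 4.0, 0.0],\n",
    "    [2.0, 4.0, 8.0],\n",
    "])\n",
    "\n",
    "# Center at the mean vertex, then divide by the largest absolute coordinate.\n",
    "toy_center = toy_verts.mean(0)\n",
    "toy_centered = toy_verts - toy_center\n",
    "toy_scale = toy_centered.abs().max()\n",
    "toy_normalized = toy_centered / toy_scale\n",
    "\n",
    "# Inverse transform: undo the scaling first, then the centering.\n",
    "toy_restored = toy_normalized * toy_scale + toy_center\n",
    "\n",
    "print(toy_normalized.abs().max().item())       # 1.0\n",
    "print(torch.allclose(toy_restored, toy_verts))  # True\n",
    "```\n",
    "\n",
    "Keeping `center` and `scale` around is what makes it possible to map the optimized mesh back to the coordinate frame of the original file."
   ]
  },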
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "6BxDTpB2WHRH"
   },
   "outputs": [],
   "source": [
    "# We initialize the source shape to be a sphere of radius 1\n",
    "src_mesh = ico_sphere(4, device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "dYWDl4VGWHRK"
   },
   "source": [
    "###  Visualize the source and target meshes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "482YycLHWHRL"
   },
   "outputs": [],
   "source": [
    "def plot_pointcloud(mesh, title=\"\"):\n",
    "    # Sample points uniformly from the surface of the mesh.\n",
    "    points = sample_points_from_meshes(mesh, 5000)\n",
    "    x, y, z = points.clone().detach().cpu().squeeze().unbind(1)    \n",
    "    fig = plt.figure(figsize=(5, 5))\n",
    "    ax = Axes3D(fig)\n",
    "    ax.scatter3D(x, z, -y)\n",
    "    ax.set_xlabel('x')\n",
    "    ax.set_ylabel('z')\n",
    "    ax.set_zlabel('y')\n",
    "    ax.set_title(title)\n",
    "    ax.view_init(190, 30)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 765
    },
    "colab_type": "code",
    "id": "UoGcflJ_WHRO",
    "outputId": "b9a2d699-2c68-4696-9dff-d30eea7a0fb0"
   },
   "outputs": [],
   "source": [
    "# %matplotlib notebook\n",
    "plot_pointcloud(trg_mesh, \"Target mesh\")\n",
    "plot_pointcloud(src_mesh, \"Source mesh\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "8uzMiTUSWHRS"
   },
   "source": [
    "## 3. Optimization loop "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Sc-3M17Ltuvh"
   },
   "outputs": [],
   "source": [
    "# We will learn to deform the source mesh by offsetting its vertices\n",
    "# The shape of the deform parameters is equal to the total number of vertices in src_mesh\n",
    "deform_verts = torch.full(src_mesh.verts_packed().shape, 0.0, device=device, requires_grad=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "0BtSUfMYtuvl"
   },
   "outputs": [],
   "source": [
    "# The optimizer\n",
    "optimizer = torch.optim.SGD([deform_verts], lr=1.0, momentum=0.9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000,
     "referenced_widgets": [
      "12fdcbc799cc4da899d889d0399616c2",
      "0bd231c2134e4127a3756807317d6aae",
      "23804ad243d44cecbff89ab0b7f40c7e",
      "be25dd06faf04bf29733cc16deefb189",
      "283601ac2fe54ecc8716aed8842a5dd2",
      "6e2ff75105a74afbb4ed3fafd414e16f",
      "5462de8f68be408d98a6a495e630f448",
      "6e1e9eb164434a06b7b1bc73e4eb4fcd"
     ]
    },
    "colab_type": "code",
    "id": "9DAjqI9Atuvp",
    "outputId": "d59e959b-8616-40fe-aec4-5b09b27e325f"
   },
   "outputs": [],
   "source": [
    "# Number of optimization steps\n",
    "Niter = 2000\n",
    "# Weight for the chamfer loss\n",
    "w_chamfer = 1.0 \n",
    "# Weight for mesh edge loss\n",
    "w_edge = 1.0 \n",
    "# Weight for mesh normal consistency\n",
    "w_normal = 0.01 \n",
    "# Weight for mesh laplacian smoothing\n",
    "w_laplacian = 0.1 \n",
    "# Plot period for the losses\n",
    "plot_period = 250\n",
    "loop = tqdm(range(Niter))\n",
    "\n",
    "chamfer_losses = []\n",
    "laplacian_losses = []\n",
    "edge_losses = []\n",
    "normal_losses = []\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "for i in loop:\n",
    "    # Initialize optimizer\n",
    "    optimizer.zero_grad()\n",
    "    \n",
    "    # Deform the mesh\n",
    "    new_src_mesh = src_mesh.offset_verts(deform_verts)\n",
    "    \n",
    "    # We sample 5k points from the surface of each mesh \n",
    "    sample_trg = sample_points_from_meshes(trg_mesh, 5000)\n",
    "    sample_src = sample_points_from_meshes(new_src_mesh, 5000)\n",
    "    \n",
    "    # We compare the two sets of pointclouds by computing (a) the chamfer loss\n",
    "    loss_chamfer, _ = chamfer_distance(sample_trg, sample_src)\n",
    "    \n",
    "    # and (b) the edge length of the predicted mesh\n",
    "    loss_edge = mesh_edge_loss(new_src_mesh)\n",
    "    \n",
    "    # mesh normal consistency\n",
    "    loss_normal = mesh_normal_consistency(new_src_mesh)\n",
    "    \n",
    "    # mesh laplacian smoothing\n",
    "    loss_laplacian = mesh_laplacian_smoothing(new_src_mesh, method=\"uniform\")\n",
    "    \n",
    "    # Weighted sum of the losses\n",
    "    loss = loss_chamfer * w_chamfer + loss_edge * w_edge + loss_normal * w_normal + loss_laplacian * w_laplacian\n",
    "    \n",
    "    # Print the losses\n",
    "    loop.set_description('total_loss = %.6f' % loss)\n",
    "    \n",
    "    # Save the losses for plotting\n",
    "    chamfer_losses.append(float(loss_chamfer.detach().cpu()))\n",
    "    edge_losses.append(float(loss_edge.detach().cpu()))\n",
    "    normal_losses.append(float(loss_normal.detach().cpu()))\n",
    "    laplacian_losses.append(float(loss_laplacian.detach().cpu()))\n",
    "    \n",
    "    # Plot mesh\n",
    "    if i % plot_period == 0:\n",
    "        plot_pointcloud(new_src_mesh, title=\"iter: %d\" % i)\n",
    "        \n",
    "    # Optimization step\n",
    "    loss.backward()\n",
    "    optimizer.step()\n"
   ]
  },
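  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop above follows the usual PyTorch pattern: reset the gradients, build the loss from the learnable offsets, backpropagate, and take an optimizer step. Below is a stripped-down sketch of the same pattern. It assumes a toy problem with known point correspondences and uses a mean squared error in place of the mesh losses; all values and `toy_` names are illustrative.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# Source points at the origin and made-up target positions, shape (4, 3).\n",
    "toy_src = torch.zeros(4, 3)\n",
    "toy_trg = torch.tensor([\n",
    "    [1.0, 0.0, 0.0],\n",
    "    [0.0, 1.0, 0.0],\n",
    "    [0.0, 0.0, 1.0],\n",
    "    [1.0, 1.0, 1.0],\n",
    "])\n",
    "\n",
    "# One learnable 3D offset per point, initialized to zero.\n",
    "toy_offsets = torch.zeros_like(toy_src, requires_grad=True)\n",
    "toy_optimizer = torch.optim.SGD([toy_offsets], lr=1.0, momentum=0.9)\n",
    "\n",
    "for step in range(200):\n",
    "    toy_optimizer.zero_grad()\n",
    "    # Stand-in loss: mean squared error to the corresponding target point.\n",
    "    loss = ((toy_src + toy_offsets - toy_trg) ** 2).mean()\n",
    "    loss.backward()\n",
    "    toy_optimizer.step()\n",
    "\n",
    "print(loss.item())\n",
    "```\n",
    "\n",
    "Only the offsets are optimized; the source points themselves stay fixed, which mirrors how `deform_verts` is the sole parameter passed to the optimizer in this tutorial."
   ]
  },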
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "VGcZsvWBWHRc"
   },
   "source": [
    "## 4. Visualize the loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 374
    },
    "colab_type": "code",
    "id": "baXvAo1yWHRd",
    "outputId": "11ebe2ad-4352-4492-bd67-e6a3c95adc85"
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(13, 5))\n",
    "ax = fig.gca()\n",
    "ax.plot(chamfer_losses, label=\"chamfer loss\")\n",
    "ax.plot(edge_losses, label=\"edge loss\")\n",
    "ax.plot(normal_losses, label=\"normal loss\")\n",
    "ax.plot(laplacian_losses, label=\"laplacian loss\")\n",
    "ax.legend(fontsize=\"16\")\n",
    "ax.set_xlabel(\"Iteration\", fontsize=\"16\")\n",
    "ax.set_ylabel(\"Loss\", fontsize=\"16\")\n",
    "ax.set_title(\"Loss vs iterations\", fontsize=\"16\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Y9vSKErDWHRg"
   },
   "source": [
    "## 5. Save the predicted mesh"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "krikJzrLtuvw"
   },
   "outputs": [],
   "source": [
    "# Fetch the verts and faces of the final predicted mesh\n",
    "final_verts, final_faces = new_src_mesh.get_mesh_verts_faces(0)\n",
    "\n",
    "# Scale normalize back to the original target size\n",
    "final_verts = final_verts * scale + center\n",
    "\n",
    "# Store the predicted mesh using save_obj\n",
    "final_obj = os.path.join('./', 'final_model.obj')\n",
    "save_obj(final_obj, final_verts, final_faces)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "quR1DVAcWHRk"
   },
   "source": [
    "## 6. Conclusion \n",
    "\n",
    "In this tutorial we learnt how to load a mesh from an obj file, initialize a PyTorch3D datastructure called **Meshes**, set up an optimization loop and use four different PyTorch3D mesh loss functions. "
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "bento_stylesheets": {
   "bento/extensions/flow/main.css": true,
   "bento/extensions/kernel_selector/main.css": true,
   "bento/extensions/kernel_ui/main.css": true,
   "bento/extensions/new_kernel/main.css": true,
   "bento/extensions/system_usage/main.css": true,
   "bento/extensions/theme/main.css": true
  },
  "colab": {
   "name": "deform_source_mesh_to_target_mesh.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
